#!/usr/bin/env python3

from typing import Literal
import pathlib
from pathlib import Path
from itertools import chain

import numpy as np
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoTokenizer, PreTrainedTokenizer


class WikiText103(Dataset):

    def __init__(
        self,
        root: str | pathlib.Path | None,
        split: Literal["train", "validation", "test"],
        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2"),
        padding: bool = False,
        truncation: bool = True,
        max_length: int = 1024,
        num_proc: int = 1,
        no_cache: bool = False,
    ) -> None:
        super().__init__()

        # args
        self.root = root
        self.split = split
        self.tokenizer = tokenizer
        self.padding = padding
        self.truncation = truncation
        self.max_length = max_length
        self.num_proc = num_proc
        self.no_cache = no_cache

        self.cache_path = Path(self.root) / f"{self.split}.pt"

        if self.cache_path.exists() and not no_cache:
            print(f"load cache data at {self.cache_path}")
            self.dataset = torch.load(self.cache_path)

        else:
            print("create dataset")
            self.dataset = self.create_dataset()

    def create_dataset(self) -> dict[str, torch.Tensor]:
        dataset = load_dataset(
            "Salesforce/wikitext",
            "wikitext-103-v1",
            cache_dir=self.root,
            split=self.split,
            num_proc=self.num_proc,
        )
        column_names = dataset.column_names
        dtype = np.uint16 if self.tokenizer.vocab_size < 64 * 1024 else np.int32

        def tokenize_concat(examples):
            tokenize = lambda example: self.tokenizer(example["text"])
            # We just need 'input_ids', not 'attention_mask' (since it's all 1)
            input_ids = np.fromiter(
                chain(*tokenize(examples)["input_ids"]), dtype=dtype
            )
            # Need to return a list since we're doing batched processing
            return {"input_ids": [input_ids], "len": [len(input_ids)]}

        dataset = dataset.map(
            tokenize_concat,
            batched=True,
            num_proc=max(self.num_proc, 1),
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )

        dataset = torch.tensor(list(chain(*dataset["input_ids"])))

        dataset = dataset[: (len(dataset) // self.max_length) * self.max_length].view(
            -1, self.max_length
        )

        torch.save(dataset, self.cache_path)

        return dataset

    def __getitem__(self, idx: int):

        return self.dataset[idx]

    def __len__(self):
        return self.dataset.size(0)


# if __name__ == "__main__":
#     data = WikiText103("./dataset", split="test")
#     print(len(data))
